# -*- coding: utf-8 -*-
"""
DPO 继续训练：以“专家DPO LoRA”为初始权重，用公众偏好数据微调 -> 融合模型
"""

import os
# Attention-related environment tweaks applied at import time, before `import torch`
# further down in this file:
#   * request that flash attention be disabled, unless the caller already set the var;
#   * drop any inherited request to disable the memory-efficient kernel.
# NOTE(review): confirm the installed PyTorch actually reads these variables.
os.environ.setdefault("PYTORCH_SDP_DISABLE_FLASH_ATTENTION", "1")
os.environ.pop("PYTORCH_SDP_DISABLE_MEM_EFFICIENT", None)

import argparse
import torch
from typing import Dict, Any
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import DPOTrainer
from trl.trainer.dpo_config import DPOConfig
from peft import PeftModel  # ✨ 关键：从已训好的LoRA加载

def get_args(argv=None):
    """Parse command-line options for continued DPO training.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets the parser be driven from Python code or tests.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    p = argparse.ArgumentParser(
        description="Continue DPO training from an expert LoRA adapter on public preference data."
    )
    # --- Paths ---
    # Still points at the *base* model directory; the LoRA adapter is loaded on top of it.
    p.add_argument("--model_path", type=str, default=r"D:\chatglm3-models\chatglm3-6b",
                   help="Base model directory.")
    # Public preference pairs (JSONL with prompt / chosen / rejected fields).
    p.add_argument("--data_jsonl", type=str, default=r"D:\chongxinxuexi\pythonProject1\data\public_pairs.jsonl",
                   help="JSONL file with prompt/chosen/rejected preference pairs.")
    # Write to a separate directory so the expert adapter is not overwritten.
    p.add_argument("--output_dir", type=str, default=r"D:\chongxinxuexi\pythonProject1\dpo_model\fusion_from_expert",
                   help="Directory for the resulting adapter, tokenizer and trainer state.")
    # Expert LoRA adapter used as the initial trainable weights.
    p.add_argument("--init_adapter", type=str, default=r"D:\chongxinxuexi\pythonProject1\dpo_model\expert",
                   help="Expert LoRA adapter directory used as the starting point.")

    # --- Optimisation ---
    p.add_argument("--epochs", type=int, default=1, help="Number of training epochs.")
    p.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
    p.add_argument("--beta", type=float, default=0.3, help="DPO beta.")
    p.add_argument("--batch", type=int, default=1, help="Per-device train batch size.")
    p.add_argument("--accum", type=int, default=1, help="Gradient accumulation steps.")
    p.add_argument("--max_length", type=int, default=700, help="Max prompt+completion tokens.")
    p.add_argument("--max_prompt_length", type=int, default=256, help="Max prompt tokens.")
    p.add_argument("--save_each_epoch", action="store_true", help="Save a checkpoint after every epoch.")
    p.add_argument("--max_steps", type=int, default=-1, help="Stop after this many steps (<= 0: no limit).")
    p.add_argument("--seed", type=int, default=42, help="Random seed.")

    # --- Data loading ---
    p.add_argument("--num_workers", type=int, default=2, help="DataLoader worker processes.")
    p.add_argument("--pin_memory", action="store_true", help="Enable DataLoader pin_memory.")
    return p.parse_args(argv)

def ensure_keys(example: Dict[str, Any]) -> None:
    """Check that a preference sample has every field DPO training needs.

    Args:
        example: One dataset row (a mapping of column name to value).

    Raises:
        ValueError: If any of ``prompt``, ``chosen`` or ``rejected`` is absent.
            All missing fields are listed, in a fixed order, so the message is
            identical from run to run. Iterating a ``set`` of strings would not
            guarantee that under hash randomisation.
    """
    need = ("prompt", "chosen", "rejected")
    missing = [k for k in need if k not in example]
    if missing:
        raise ValueError(f"样本缺少字段 {', '.join(missing)}，需要包含 {', '.join(need)}")

def print_step(msg: str):
    """Emit a progress line tagged ``[INFO]`` and flush stdout immediately."""
    line = "[INFO] {}".format(msg)
    print(line, flush=True)

def _configure_torch(seed: int) -> None:
    """Seed the RNGs and, when CUDA is present, set matmul/SDPA kernel preferences."""
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cuda.matmul.allow_tf32 = True
    try:
        torch.set_float32_matmul_precision("high")
    except AttributeError:  # older torch without this API
        pass
    # Use the global SDPA toggles. `torch.backends.cuda.sdp_kernel` is a
    # context-manager function and has no `enable_*` attributes, so calling
    # those attributes would only raise.
    try:
        torch.backends.cuda.enable_flash_sdp(False)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        # Math SDPA is deliberately left enabled as a fallback. With it disabled,
        # SDPA raises whenever no other enabled kernel supports the inputs.
    except AttributeError:
        print_step("当前 PyTorch 不支持 SDPA 后端开关，已跳过")


def _load_policy_model(args: argparse.Namespace, tokenizer, dtype: torch.dtype):
    """Load the base causal LM on GPU 0 and wrap it with the trainable expert LoRA adapter."""
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map={"": 0},
    )
    model.gradient_checkpointing_enable()
    # The base weights are frozen, so the embedding output does not require grad.
    # With reentrant checkpointing, no gradient would then reach the LoRA weights
    # inside checkpointed layers. This hook makes the embedding output require grad.
    try:
        model.enable_input_require_grads()
    except (AttributeError, NotImplementedError):
        print_step("模型不支持 enable_input_require_grads，梯度检查点下 LoRA 可能无梯度")
    model.config.use_cache = False
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.eos_token_id = tokenizer.eos_token_id

    # Continue from the expert LoRA. The result is a PeftModel whose adapter is
    # trainable, so no new peft_config is passed to the trainer.
    print_step(f"加载专家 LoRA 作为初始权重：{args.init_adapter}")
    model = PeftModel.from_pretrained(model, args.init_adapter, is_trainable=True)
    # Keep the trainable (adapter) parameters in fp32. With fp16 AMP, GradScaler
    # refuses to unscale fp16 gradients. `.float()` is a no-op for tensors that
    # are already fp32.
    for param in model.parameters():
        if param.requires_grad:
            param.data = param.data.float()
    return model


def _load_preference_dataset(path: str):
    """Load the JSONL preference pairs and validate the first row's fields."""
    ds = load_dataset("json", data_files=path)["train"]
    if len(ds) == 0:
        raise RuntimeError("数据集为空，请检查 --data_jsonl")
    ensure_keys(ds[0])
    return ds


def main():
    """Continue DPO training of the expert LoRA adapter on public preference data.

    Writes the resulting adapter, the tokenizer and the trainer state to
    ``--output_dir``.
    """
    args = get_args()
    # Set before any CUDA allocation; setdefault respects a value exported by the user.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    _configure_torch(args.seed)

    # ---- Tokenizer
    print_step("加载 Tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    tokenizer.padding_side = "right"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the weights in the same dtype that mixed-precision training will use.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    compute_dtype = torch.bfloat16 if use_bf16 else torch.float16

    # ---- Base model + expert LoRA
    print_step("加载 Base 模型")
    model = _load_policy_model(args, tokenizer, compute_dtype)

    # ---- Dataset
    print_step("加载公众偏好数据")
    train_dataset = _load_preference_dataset(args.data_jsonl)

    # ---- DPO training config
    training_args = DPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch,
        gradient_accumulation_steps=args.accum,
        num_train_epochs=args.epochs,
        # Trainer.train() does not accept max_steps, so the limit belongs here.
        # -1 means "derive the step count from num_train_epochs".
        max_steps=args.max_steps if args.max_steps > 0 else -1,
        learning_rate=args.lr,
        beta=args.beta,
        max_length=args.max_length,
        max_prompt_length=args.max_prompt_length,
        remove_unused_columns=False,
        logging_steps=10,
        eval_strategy="no",
        save_strategy="epoch" if args.save_each_epoch else "no",
        fp16=(not use_bf16),
        bf16=use_bf16,
        report_to=["tensorboard"],
        logging_dir=os.path.join(args.output_dir, "tb"),
        dataloader_num_workers=args.num_workers,
        dataloader_pin_memory=args.pin_memory,
        max_grad_norm=1.0,
    )

    # ---- Trainer (the model is already a PeftModel; no peft_config is passed)
    # NOTE(review): with ref_model=None and a PeftModel, TRL uses the model with
    # adapters *disabled* (the plain base) as the DPO reference, not the expert
    # adapter. Confirm this is intended. Otherwise consider TRL's ref_adapter_name.
    trainer = DPOTrainer(
        model=model,
        ref_model=None,
        args=training_args,
        processing_class=tokenizer,
        train_dataset=train_dataset,
    )

    print_step("开始训练（公众数据，继续在专家LoRA上优化）")
    trainer.train()

    # For a PeftModel, save_model writes the adapter weights/config rather than a merged model.
    print_step("保存融合模型")
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    trainer.save_state()
    print_step(f"✅ 完成，融合模型保存在：{args.output_dir}")

# Script entry point: start training only when the file is executed directly, not on import.
if __name__ == "__main__":
    main()
